{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using TensorFlow backend.\n"
     ]
    }
   ],
   "source": [
    "import keras\n",
    "from keras.datasets import cifar10\n",
    "from keras.preprocessing.image import ImageDataGenerator\n",
    "from keras.models import Sequential\n",
    "from keras.layers import Dense, Dropout, Activation, Flatten\n",
    "from keras.layers import Conv2D, MaxPooling2D\n",
    "import numpy as np\n",
    "\n",
    "import os\n",
    "import wandb\n",
    "from wandb.keras import WandbCallback\n",
    "from PIL import Image\n",
    "from matplotlib import pyplot as plt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "W&B Run: https://app.wandb.ai/qualcomm/Distillation/runs/lx9bmpjk\n",
      "Call `%%wandb` in the cell containing your training loop to display live results.\n"
     ]
    }
   ],
   "source": [
    "run = wandb.init(project=\"Distillation\",tensorboard=True)\n",
    "config = run.config\n",
    "config.dropout = 0.25\n",
    "config.dense_layer_nodes = 100\n",
    "config.learn_rate = 0.32  #0.01 \n",
    "config.batch_size = 1024  #32\n",
    "config.epochs = 50"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "class_names = ['airplane','automobile','bird','cat','deer',\n",
    "               'dog','frog','horse','ship','truck']\n",
    "num_classes = len(class_names)\n",
    "\n",
    "(X_train, y_train_original_label), (X_test, y_test) = cifar10.load_data()\n",
    "\n",
    "# Convert class vectors to binary class matrices.\n",
    "y_train_original_label = keras.utils.to_categorical(y_train_original_label, num_classes)\n",
    "y_test = keras.utils.to_categorical(y_test, num_classes)\n",
    "\n",
    "# Normalize values into [0 to 1] because all default assume zero to one\n",
    "# We won't bother to fully sphere the data, removing median and divide by variance\n",
    "X_train = X_train.astype('float32') / 255.\n",
    "X_test = X_test.astype('float32') / 255."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "y_train_softmax_output = np.load(\"y_train_softmax_output.npy\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define Emulation Model\n",
    "# Note: We can set some of these at variables to reduce capacity \n",
    "model = Sequential() # Sets up model structure to pass data forward \n",
    "model.add(Conv2D(32, # Activation layers \n",
    "                 (3, 3), # Conv dim\n",
    "                 padding='same', # Use padding to maintain same dims\n",
    "                 input_shape=X_train.shape[1:], activation='relu'))\n",
    "model.add(MaxPooling2D(pool_size=(2, 2))) # 16x16 dim per activation\n",
    "model.add(Dropout(config.dropout)) # Add a dash of regularization (too taste)\n",
    "\n",
    "# NOTE: Would be fun to try a 16x16xActivation conv to a 1x1x10 tensor\n",
    "model.add(Flatten()) # Because dense layers like 1D vectors \n",
    "model.add(Dense(config.dense_layer_nodes, activation='relu'))\n",
    "model.add(Dropout(config.dropout)) # A pinch more regularization \n",
    "model.add(Dense(num_classes, activation='softmax')) # Plate and serve\n",
    "\n",
    "# NOTE: Could also try Adam at some point\n",
    "opt = keras.optimizers.SGD(lr=config.learn_rate) # No momentum \n",
    "\n",
    "model.compile(loss='categorical_crossentropy', # Because one hot encoding\n",
    "              optimizer=opt, # Per previous decision to not use momentum\n",
    "              metrics=['accuracy']) # Easier to look at than cross entropy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAP8AAAD8CAYAAAC4nHJkAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAAH3VJREFUeJztnVuMXNd1pv9Vt67qezf7QrJJiRJ1ieRYomRG0MiejB0jgWIEkQ0Ejv1g6MEIgyAGYiB5EDzA2APMgz0Y2/DDwAN6pEQZeHyJL7EQCEkcwYGQOFBEWbLukSiKMi/NZpPdze7qqq7rmocqTaj2/jdLvFRT2v8HEKw+q/Y56+w665w656+1lrk7hBDpkdlqB4QQW4OCX4hEUfALkSgKfiESRcEvRKIo+IVIFAW/EImi4BciURT8QiRK7lIGm9m9AL4GIAvgf7v7F2Pvz+fzPlAsBm2tVouOyyD8K8Ss8W0Vcvy8lo/YctkstZmFN2gWOYdGfGw2+T7HfneZjflIfrHZ9jbfVptvzTKRHYjQbof3LeZ7dH0R/y0yycyWifiRzfDPkx0DANCO/FrWYwcCGxNdX5illTWUKxs9beyig9/MsgD+J4DfBHAcwJNm9oi7v8jGDBSL2Hfn+4K2lZUluq2BTPiDnyzwyblm2yC1TU8OUdvU+DC1FbL54PLcQImOQZZP8dLyCrXVm3zfJsbHqC3TagSX12o1OmZjY4PaiqXwyRoAWuAnr0q1HFw+Nj5Kx8D5+uq1OrVlEf5cAH6yGRnmn/PQED8+8nk+H9WIjx67QGTCx0hsn5seju8vPfh9vp3Nm+35nb/MXQAOu/sRd68D+DaA+y5hfUKIPnIpwT8H4Nh5fx/vLhNCvAO4pHv+XjCzAwAOAMDAwMCV3pwQokcu5cp/AsDu8/7e1V32Ftz9oLvvd/f9uTy/NxNC9JdLCf4nAdxoZteZWQHAJwA8cnncEkJcaS76a7+7N83sMwD+Dh2p7yF3fyE2ZmNjAy+8GH7LypkzdNwkecBq2/iT16nWCLVZaYba1ttcdSi3wk/g3Qp0TGWDP7GtVPkT+EaLS1tnIhpnMRf2sdnk68uSp81A/FatsrFObc12eL9tYxsdk4mogI2IWlHK8eOgTJ6YL7WadMzgIH/abxn+7dWIGgQAiMiHlY2wQtNshJcDQDYX/lwaG1XuwyYu6Z7f3R8F8OilrEMIsTXoF35CJIqCX4hEUfALkSgKfiESRcEvRKJc8V/4nU8GQClHZKrIj/+uJZLenlme4DIzPUltpZiUE8naqtbCCTAbDS5DeWR9hVIkISiS2ONtvr2xyXBCU7PB11fIcz8iyZbIFviHVquH56rR5PMxGFlfboj7WIyMa1pYjsxEsgSbkQy8WCbp8BBPJiuvV6it0QxLerGEyrXVc8Hl7dgHtnn9Pb9TCPGuQsEvRKIo+IVIFAW/EImi4BciUfr6tN/MUbRwQsXICHflprmJ4PJtJZ4Jkm/z0lTlJZ5s02rz82G1EvY9w/N6MBopC5aLPKVeObfGx0U+tcmR8BPntVWehFOPJOhUSdIJEK9LN0xKYTXqPPEk0+I7lo8kGLVI6TIAyJHH87UaH1PI8w800+YJQbXyMrWBJIUBwAA5jJttrkicWw8rPq1IPcbN6MovRKIo+IVIFAW/EImi4BciURT8QiSKgl+IROmr1Jczw8RAeJOliJQzRpI6pkd5zbQWaRcFINJnBsjmIoXkSB22WjsiNUV0uVwkuaRV45KYZ/k5+/TpcBegVoPv9VqFJ51UWlwWHS5Fuu/USLsu8H3OGJepsgORTjnrXNYdzId9zEVaYW1E6i5WG1zqa0earK2UuY8rlfDxUybSMgBsNMLHQD1Sq3EzuvILkSgKfiESRcEvRKIo+IVIFAW/EImi4BciUS5J6jOzowDW0FHPmu6+P7qxrGF6PCzZjOS5xFYshm2ZLJdWSpH6eI0ml73akUw197AEVI/U22vVuQzY9kjGXERi8xzPOlurhzP0Wi0+v5VIa7BmxLa2zv0/sRT2I5/h6xst87lvnOLt3KrnuFR5zdQNweUzM7voGBsJ18cDgNryWWorl3l25Lk1LvWdOReWdY8e4360suHQrdW5PLiZy6Hzf8jd+ScjhLgq0dd+IRLlUoPfAfy9mT1lZgcuh0NCiP5wqV/7P+DuJ8xsBsCPzexld3/8/Dd0TwoHAKAYua8XQvSXS7ryu/uJ7v+nAfwQwF2B9xx09/3uvr+Q012GEFcLFx2NZjZkZiNvvgbwWwCev1yOCSGuLJfytX8WwA+77a1yAP6vu/9tbEA+l8XO6XBhx9EClyiGB8PSlkWkMkQyrCySTVerctkoQ2TAbSO8bdjQEM9GWz3HRZKxUZ4xtxYpqvnGifA6yzV+y1WIJILNDUayEvM88/Do2XB2Yc0jRVcjWX1joyPUds+tXGFenQ/Lul6JbGuKZ4vWKnw+ymV+LR3I83Xu3h7et5mZWTpmYTUsHZ595RQds5mLDn53PwLg9osdL4TYWnQTLkSiKPiFSBQFvxCJouAXIlEU/EIkSn8LeGYNkyPhbLtcPSwNAcBAPuzm4EC4Lx0A1KpcDmtE+q2Nj4f7AgKAk6KP9RY/hzYakeKSw7yP38nFcC82AHjtDZ7ttbgW3rdILUhcG+l5+NH/uI/adu3g/n/vqSPB5f9ymEtRzTbPZMxluDS3trJIbZVyeB5HRrj0hhbPLiwW+bgCyT4FgEHj45qt8Idzze6ddMzIUriX47Ov87nYjK78QiSKgl+IRFHwC5EoCn4hEkXBL0Si9Pdpfy6HmcltQVt1iT8Vz1jYzTJpcwQA1Ugts5xF6tlF2lqxM2W1wZ9Sj0/wBJ16iz/BPnL8JLUtrXIfWX2/bKTF12iRr28mF36qDADFJa5I3Di6Pbh8fpL7sbBymtpqFT7HT7/yCrVlSPuqxlCk1dgYT6hBhofM2BhXn0bakfZgpM6j11fpmD0kQW4g3/v1XFd+IRJFwS9Eoij4hUgUBb8QiaLgFyJRFPxCJEqfpb48Jqamg7aJYd5eK5MJJ0WsrC7TMY31Ml9fK9auixe0c5JgNDzM6/Q1wG0vHeES1XqNt34qFge4rRD2sTTEZaiJLJdFnzq8QG3NOj98amNhqW96gs+HgctvjSaXgit1XktwndTqqzf5PltEuo10c0M+E2n1lonULsyF57FZ41KqE5mY5J4F0ZVfiERR8AuRKAp+IRJFwS9Eoij4hUgUBb8QiXJBqc/MHgLwOwBOu/uvdpdNAvgOgD0AjgL4uLtz3e3f1wYQ2c4i7YwYA5F6aoMIZz0BQC5yzstkIvX4iAw4UOLtus6c4llxlTN8yq6f5JJYjateKBJJ7+a9c3RMJrLCZpbP8WpEas1lw3UGRwr8c9k2sZfa9t54DbW9/osnqe3lV04ElxdyERnNuUzcbPKQyZCMSgDIF/g8ttvh46od0RXNwsdpRIn8JXq58v8FgHs3LXsAwGPufiOAx7p/CyHeQVww+N39cQBLmxbfB+Dh7uuHAXz0MvslhLjCXOw9/6y7z3dfn0KnY68Q4h3EJT/w804xe/qjQjM7YGaHzOzQWiVysyqE6CsXG/wLZrYDALr/0/pL7n7Q3fe7+/6RQf4QSwjRXy42+B8BcH/39f0AfnR53BFC9ItepL5vAfgggCkzOw7g8wC+COC7ZvZpAG8A+HgvG2u7o7oRLlZoDZ6ZBYQzsNbXeYHDeoOf15oZ/g2kXOHS3Cqxze3m0+hNvr5rp7gws3cnl4YqG3zc3E23B5cXnN9yLZ/jhVBL4+GCqwCAszxTbff2HcHlK+s8W/H6X7mR2kYneFbi6MQt1La8GJ7/5XO85Vk+IkdmnGdUNtqRbFGeLIpWI3x8R5IEaeu4t5HUd+Hgd/dPEtOH38Z2hBBXGfqFnxCJouAXIlEU/EIkioJfiERR8AuRKH0t4OlwtCwsh3iLF1RkskapyIt+Do9waejkIpcVXz++SG25fNiPwgLvq7exwNd34wyX8z78QS57vXZic6rFvzMyFy6QOrUtXFATAE4v8iKd4+MR2avN/S+QgpWnF8NZdgCQK65Q2+LKPLWdmOdZePl8+DgYH+XaW7XKBTPP8eulRbS5dkQGzFh4nEUyTCNtHntGV34hEkXBL0SiKPiFSBQFvxCJouAXIlEU/EIkSl+lvmw2g/Hx4aCtmeNSX7kczkjzBpdPzq3xrK03fsGlrXKZy0alYvhcOf86zy6cLfKijnNz11Lb+M7rqC2/FkkRI0VNd91+Fx9yistvpSaXKlvgmYLr62HbjsGwFAkA9RbfLxsKHzcAsGtoJ7WNjIclzrWzp+iY0wtnqa1hXN7cqPOioMhwbW5oIJxlWq9GJExSENSIbBh0qed3CiHeVSj4hUgUBb8QiaLgFyJRFPxCJEpfn/a3W02srYSfpObqvNZdnrQmAi8hh1yWGytlrgRMjPBElvGh8FPZ6jJ/2j+zk9fAm7vtP1Hb88fr1PbKYW67Z8dkcPnKCh8zuzdc9w8AMqhQW73GlYBxDz+5Xz3Nn6SX6ryW4I7J8H4BwEqL19XL3zYRXF6NJAr986OPUNvxY3yfs5GWXLFGWiyPqBFrK9cIzxVLgguuo+d3CiHeVSj4hUgUBb8QiaLgFyJRFPxCJIqCX4hE6aVd10MAfgfAaXf/1e6yLwD4AwBv6h6fc/dHe9lgligerUgSgxOZJEPaeAFAy7jUt8wVJayuRuq31cJy2Y4xLg/+2oc+RG27br6b2n7w5w9R2/ZIkku2Hq5PeOLIa3x9199KbcVtN1DbkHN5trIU7t1aaoelNwCoV7mseGaN28aneRLUtu17gsur5VE6JsNNaBV4MlOshl+jwaVWa4YT1Mx54lqzGQ7dyy31/QWAewPLv+ru+7r/egp8IcTVwwWD390fB8DLxQoh3pFcyj3/Z8zsWTN7yMz4dzkhxFXJxQb/1wHsBbAPwDyAL7M3mtkBMztkZofKFX7fI4ToLxcV/O6+4O4td28D+AYAWibG3Q+6+3533z88yKvaCCH6y0UFv5ntOO/PjwF4/vK4I4ToF71Ifd8C8EEAU2Z2HMDnAXzQzPYBcABHAfxhLxszAEaUiBbJUgJ426JI5yR4NbK+SAm8yW28zdf2wbC0eOf+m+iYW+7hct7yaS5vDjR55uH1u3ZRW5vs3PYZXjuvucEl00okG7De5OMa1fCh1QKXKV87cZzannv+ELXdczf3cdv2cFbl6lpYigQA0uELADC1h8u67Vh7rXpEtiMS8rlF3r6sthZ2sk2yKUNcMPjd/ZOBxQ/2vAUhxFWJfuEnRKIo+IVIFAW/EImi4BciURT8QiRKXwt4ugNtksFUrXGJokCy2HI5XjAxm+Hyzw3b+a+RiyV+Ptxz7e7g8ts/wDP3dtx8G7U98y9/Tm3X7OY+bn/Pe6mtML03uDw3OEbHVDa45Fhd5Zl7CyePUdvyQli2azV4dl5pJFwgFQCmpvhnfezk09Q2u2MuuLxZiWSRVnnbLVtfpraWhzMqAcCZxg2gNBDet8J2vs+rAyTT9W1EtK78QiSKgl+IRFHwC5EoCn4hEkXBL0SiKPiFSJS+Sn1mhnw2vMnlSIHG1kZY1igNluiYbIZLKzORzL1j8zyTau+doVKGwK73hpd34JJdY22d2sZGuDQ3fdM+alvPhXvavfD0k3RMrcr9WF3l83HmxC+oLdsKS63FIj/k5q4Ly3IAcNtNvJBoM8sz7fLZ8fDyAs/6zG3wIp2VN05QG5OxAaAZucyWSV/JwW18v2ZJD8h8vvfrua78QiSKgl+IRFHwC5EoCn4hEkXBL0Si9Dexp91GrRp+kjo4wF2xYvhpaD7Da8h5i9tKw7yV1+/+/u9S2z2//eHg8tGpWTpm4chL1JaN+L+yxmv4LR79N2o7uRZ+4vyPf/3XdMxwiSeQbNR4Asz2Wa5IjI6En1S/fpwnA9Uj8zG5cw+13fTe91EbWgPBxUsrvF5ghahLALBc5T6a82N4o8oT18qkxZaXuepwS1jEQLv3bl268guRKgp+IRJFwS9Eoij4hUgUBb8QiaLgFyJRemnXtRvAXwKYRac910F3/5qZTQL4DoA96LTs+ri78wJnAByOtpPaem2eFGHNsEzS9EhLrkjNtOLAKLXtex+XjQbyYUnsxWd4Dbnlk69RW63GpZy15SVqO3b4RWorezjZKd/i2xrOcelztMiTS6YnuNQ3v3AquLwZactWWeOy4rHXeRIR8AK1lMvhGoTFHD8+mgMz1Ha2yY+dUonXIBwc4UlopVxYjlyrrNIxzXZYcnwbSl9PV/4mgD9191sB3A3gj83sVgAPAHjM3W8E8Fj3byHEO4QLBr+7z7v7z7qv1wC8BGAOwH0AHu6+7WEAH71STgohLj9v657fzPYAuAPAEwBm3X2+azqFzm2BEOIdQs/Bb2bDAL4P4LPu/pabEXd3kNsNMztgZofM7NB6ldfSF0L0l56C38zy6AT+N939B93FC2a2o2vfASDY8NzdD7r7fnffP1QqXA6fhRCXgQsGv5kZgAcBvOTuXznP9AiA+7uv7wfwo8vvnhDiStFLVt/7AXwKwHNm9kx32ecAfBHAd83s0wDeAPDxC6/KAYRlu3aT3xLk8uGae61IzbQ6ePbV7Bivq/d3j/wNtU3OhiWlmR3hNl4AUK/w7Lx8PizxAMDwEJeUchkuzQ0ROXL7TLjmGwBU17hCW8pyH88unqG2Rj382YwUueRVL3Op79WnD1Hb/MuvUFutSVpo5fkctmLzu4tLnxjix3BmgEutRSLbTYDP1S3vuS64vFQ8Qsds5oLB7+7/BIDlOIZzXIUQVz36hZ8QiaLgFyJRFPxCJIqCX4hEUfALkSh9LeAJN7TbYeGgEMksK+ZI8cMML7TokRZO7TrPLDtzJpyNBgDlxbCt1ODZV23w/Zqc4PLb+M5pamu2atR24mTYR4/ke2Uy/DCoN7lkmjVe+HOoGJZnSYJmZ30xYyRLs1XncmqGHG+rFS5v1geIPAhgZCef+/USb2221uYy4MZ6+Bq8bfR6OmaKSLe5fO8hrSu/EImi4BciURT8QiSKgl+IRFHwC5EoCn4hEqW/Uh8MGQtniRUHeAaTkwy9oVJYTgKAoZEpaqs0eIbVthFecyBH/KifW6Bj2hm+vkqeS1uzs+GsLQBo17lsdPNtu4LLf/qTx+iYuleoLW9cTq2W+bjRkXBWYiHHD7msRfrZbfDP7PV5LtutrIQ/s5qt0zHTN/Fr4tx4JCvR+We9fIbPVWEjLJkOzUUyMSvhrMl2RC3djK78QiSKgl+IRFHwC5EoCn4hEkXBL0Si9PVpf8aAQi58vqnUeMJElrSMakfqy1UaPDkjm+dJIgMF/jQ3nw/7URjkbavGRnmC0alFrhJU5sJP7QFgZvcN1HbidLiu3nt+7f10THnxJLUdeYW3wlov80SWXDY8/2NjvDahkfqOADB/gvv4izciiT0D4fkfneVK0fRkxMeI6mBL/LOeWOahNjczGVy+a5wfA4dfDCdw1ao8aW0zuvILkSgKfiESRcEvRKIo+IVIFAW/EImi4BciUS4o9ZnZbgB/iU4Lbgdw0N2/ZmZfAPAHABa7b/2cuz8a3VjOMDsdPt80zp6l46qtsAS0znMz4BneyisXSS4ZHeXJFAXSCqu6zmv4lWI11ercduinP6W262/mEuHx42EJKBOpdzg4wGvxZSNyaqnEpa31cljqq1a5BNuMtGwbLnE/7rnjJmorkgSjZpbXJmw1eBJO9RiX+jJrRWqbGRyhtjtuek94zDjvev/U/OvB5c0G36/N9KLzNwH8qbv/zMxGADxlZj/u2r7q7v+j560JIa4aeunVNw9gvvt6zcxeAjB3pR0TQlxZ3tY9v5ntAXAHgCe6iz5jZs+a2UNmxlvfCiGuOnoOfjMbBvB9AJ9191UAXwewF8A+dL4ZfJmMO2Bmh8zs0GqF39MJIfpLT8FvZnl0Av+b7v4DAHD3BXdvuXsbwDcA3BUa6+4H3X2/u+8fHeSVToQQ/eWCwW9mBuBBAC+5+1fOW77jvLd9DMDzl989IcSVopen/e8H8CkAz5nZM91lnwPwSTPbh478dxTAH15oRYWC4Zrd4av/mHGZ5PCxsPSysMiz8+otLg0ND/PdXq/wDLFWuxxcno2cQ5cWuYS5VuayzEaD+5F1bhsZDj96WTi1RMccX+fyVdu5RDg7zWVRa4ezy5ZXeL29gSH+mY2PcamskOXzX6sTyTfH5c31Gl9fvRxpUdbm427YvZ3adm4Pz+Ox41zSPbsYjolmrOXZJnp52v9PAEJHQFTTF0Jc3egXfkIkioJfiERR8AuRKAp+IRJFwS9EovS1gGc2ZxidIJlxRLoAgImZbNgwxIswnlngBUE3Iu2ucgVevJENazd4BmGjxf04V+Wy11Aki22jwqW56ka4gGc94mMrYnMncw+gvBpp1zUaLoQ6OsqLnVarfH1nzvK5Gh7m2YWWCV/frMll4kKOF3Ed4Io0CgU+V3tu2ENt1UrYl8cff5GOefaV0+F1bfSe1acrvxCJouAXIlEU/EIkioJfiERR8AuRKAp+IRKlr1KfmSFXDG+yOMpz/SeHw+eoXJXLaPkSz25ajfRNQ4ufD0vFmfCQPN9Wq8b72RUGuR/5HJ+PbJZLnDUP+1JvcHnTI5l7xhUxeJ1Lji1iykey6VDg8ubKMpf6qnXen25sPCzd5ogECACZyNxXwKW0hTNr1LYcyeBcWw9naf7DP77Mt0VU0Y26pD4hxAVQ8AuRKAp+IRJFwS9Eoij4hUgUBb8QidJXqa/dNpRZAcTsMB03PBTWjfIlrkMNRdKvxsa4NFde5b3kyqvhgorlSiSrb4PbRgq8AGaR9AUEgGaNS5y5XPh8Xoic5vMDPBvNjA8cjBRCzRBTs8WlqEIp0kNxnMubS0tcYlsj0ufoJJ/7SqRn4KtHeUHWl587Rm2zkzxbdHYX2bcMP06nSEHThTUue/7S6nt+pxDiXYWCX4hEUfALkSgKfiESRcEvRKJc8Gm/mRUBPA5goPv+77n7583sOgDfBrANwFMAPuXu0Ta89Tpw/I2wrbbCn86PTIefEBdLkYQOLh5gcpLvdnmd15FbWQnbls/yRJBl/nAY2TZ/yt52rmS0WlxBQDtsi53lLcMTe7I5PlfVSBKUk4f6edLGCwCaFd5SrBWp79eKJAutlMPjWBcvAFiKKD5HD/MPdOXsOrXV1/kGt4+FW3ndcu0cHcNcfPXUKh2zmV6u/DUAv+Hut6PTjvteM7sbwJcAfNXdbwCwDODTPW9VCLHlXDD4vcObHSrz3X8O4DcAfK+7/GEAH70iHgohrgg93fObWbbbofc0gB8DeA3Aivv//3J3HAD/jiKEuOroKfjdveXu+wDsAnAXgF/pdQNmdsDMDpnZoXNlXvxBCNFf3tbTfndfAfATAP8BwLiZvfk0aBeAE2TMQXff7+77x4YjHQ+EEH3lgsFvZtNmNt59XQLwmwBeQuck8Hvdt90P4EdXykkhxOWnl8SeHQAeNrMsOieL77r735jZiwC+bWb/DcDTAB680Irccmjlp4K2RmE/HVdrhxNZMs1wayoAKI5x+Wp8mn8DmcjwxJPJSjjRYmWJt3daOcPlvOo6n/5Wk8uHcH7ObjfDPm5U+S1XoRCpF5jj/q9t8MSTKrnFy0fU4JFMOFkFANoZLmE1GnweB4bCkmkxz+sFjhe4j9djnNreeztvG3bzbbdT254bbgguv+tuLm8eP1kOLv/n13hMbOaCwe/uzwK4I7D8CDr3/0KIdyD6hZ8QiaLgFyJRFPxCJIqCX4hEUfALkSjmkeyxy74xs0UAb+b1TQHoXZe4csiPtyI/3so7zY9r3X26lxX2NfjfsmGzQ+7OxX35IT/kxxX1Q1/7hUgUBb8QibKVwX9wC7d9PvLjrciPt/Ku9WPL7vmFEFuLvvYLkShbEvxmdq+Z/ZuZHTazB7bCh64fR83sOTN7xswO9XG7D5nZaTN7/rxlk2b2YzN7tfv/xBb58QUzO9Gdk2fM7CN98GO3mf3EzF40sxfM7E+6y/s6JxE/+jonZlY0s381s593/fiv3eXXmdkT3bj5jplFUj97wN37+g9AFp0yYNcDKAD4OYBb++1H15ejAKa2YLu/DuBOAM+ft+y/A3ig+/oBAF/aIj++AODP+jwfOwDc2X09AuAVALf2e04ifvR1TgAYgOHu6zyAJwDcDeC7AD7RXf6/APzRpWxnK678dwE47O5HvFPq+9sA7tsCP7YMd38cwOY61fehUwgV6FNBVOJH33H3eXf/Wff1GjrFYubQ5zmJ+NFXvMMVL5q7FcE/B+D8dqZbWfzTAfy9mT1lZge2yIc3mXX3+e7rUwBmt9CXz5jZs93bgit++3E+ZrYHnfoRT2AL52STH0Cf56QfRXNTf+D3AXe/E8BvA/hjM/v1rXYI6Jz50TkxbQVfB7AXnR4N8wC+3K8Nm9kwgO8D+Ky7v6V0Tz/nJOBH3+fEL6Fobq9sRfCfALD7vL9p8c8rjbuf6P5/GsAPsbWViRbMbAcAdP8/vRVOuPtC98BrA/gG+jQnZpZHJ+C+6e4/6C7u+5yE/NiqOelu+20Xze2VrQj+JwHc2H1yWQDwCQCP9NsJMxsys5E3XwP4LQDPx0ddUR5BpxAqsIUFUd8Mti4fQx/mxMwMnRqQL7n7V84z9XVOmB/9npO+Fc3t1xPMTU8zP4LOk9TXAPznLfLhenSUhp8DeKGffgD4FjpfHxvo3Lt9Gp2eh48BeBXAPwCY3CI//g+A5wA8i07w7eiDHx9A5yv9swCe6f77SL/nJOJHX+cEwG3oFMV9Fp0TzX8575j9VwCHAfwVgIFL2Y5+4SdEoqT+wE+IZFHwC5EoCn4hEkXBL0SiKPiFSBQFvxCJouAXIlEU/EIkyv8DgvpxjWxt2GcAAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "data = X_train[0]\n",
    "from matplotlib import pyplot as plt\n",
    "plt.imshow(data, interpolation='nearest')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "array([0.04264229, 0.00540979, 0.07433245, 0.20201543, 0.07983524,\n",
       "       0.23577349, 0.24125147, 0.11060702, 0.00521307, 0.00291974],\n",
       "      dtype=float32)"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model2 = keras.models.load_model(\"lukas.h5\")\n",
    "y_train_softmax_output = model2.predict(X_train)\n",
    "print(class_names)\n",
    "y_train_softmax_output[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "wandb: Wandb version 0.7.2 is available!  To upgrade, please run:\n",
      "wandb:  $ pip install wandb --upgrade\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train on 50000 samples, validate on 10000 samples\n",
      "Epoch 1/50\n",
      "50000/50000 [==============================] - 5s 110us/step - loss: 2.2905 - acc: 0.5324 - val_loss: 2.3056 - val_acc: 0.1059\n",
      "Epoch 2/50\n",
      "50000/50000 [==============================] - 3s 56us/step - loss: 2.2870 - acc: 0.6352 - val_loss: 2.3029 - val_acc: 0.1084\n",
      "Epoch 3/50\n",
      "50000/50000 [==============================] - 3s 55us/step - loss: 2.2866 - acc: 0.6830 - val_loss: 2.3013 - val_acc: 0.1092\n",
      "Epoch 4/50\n",
      "50000/50000 [==============================] - 3s 56us/step - loss: 2.2863 - acc: 0.7110 - val_loss: 2.3005 - val_acc: 0.1097\n",
      "Epoch 5/50\n",
      "50000/50000 [==============================] - 3s 57us/step - loss: 2.2862 - acc: 0.7287 - val_loss: 2.2999 - val_acc: 0.1094\n",
      "Epoch 6/50\n",
      "50000/50000 [==============================] - 3s 57us/step - loss: 2.2861 - acc: 0.7406 - val_loss: 2.2992 - val_acc: 0.1090\n",
      "Epoch 7/50\n",
      "50000/50000 [==============================] - 3s 56us/step - loss: 2.2860 - acc: 0.7520 - val_loss: 2.2989 - val_acc: 0.1088\n",
      "Epoch 8/50\n",
      "50000/50000 [==============================] - 3s 56us/step - loss: 2.2860 - acc: 0.7593 - val_loss: 2.2986 - val_acc: 0.1095\n",
      "Epoch 9/50\n",
      "50000/50000 [==============================] - 3s 57us/step - loss: 2.2860 - acc: 0.7659 - val_loss: 2.2985 - val_acc: 0.1093\n",
      "Epoch 10/50\n",
      "50000/50000 [==============================] - 3s 57us/step - loss: 2.2859 - acc: 0.7703 - val_loss: 2.2983 - val_acc: 0.1091\n",
      "Epoch 11/50\n",
      "50000/50000 [==============================] - 3s 56us/step - loss: 2.2859 - acc: 0.7766 - val_loss: 2.2979 - val_acc: 0.1090\n",
      "Epoch 12/50\n",
      "50000/50000 [==============================] - 3s 56us/step - loss: 2.2859 - acc: 0.7800 - val_loss: 2.2978 - val_acc: 0.1090\n",
      "Epoch 13/50\n",
      "50000/50000 [==============================] - 3s 57us/step - loss: 2.2859 - acc: 0.7828 - val_loss: 2.2975 - val_acc: 0.1086\n",
      "Epoch 14/50\n",
      "50000/50000 [==============================] - 3s 57us/step - loss: 2.2858 - acc: 0.7853 - val_loss: 2.2975 - val_acc: 0.1093\n",
      "Epoch 15/50\n",
      "50000/50000 [==============================] - 3s 56us/step - loss: 2.2859 - acc: 0.7903 - val_loss: 2.2971 - val_acc: 0.1093\n",
      "Epoch 16/50\n",
      "50000/50000 [==============================] - 3s 57us/step - loss: 2.2859 - acc: 0.7911 - val_loss: 2.2972 - val_acc: 0.1101\n",
      "Epoch 17/50\n",
      "50000/50000 [==============================] - 3s 57us/step - loss: 2.2858 - acc: 0.7955 - val_loss: 2.2969 - val_acc: 0.1101\n",
      "Epoch 18/50\n",
      "50000/50000 [==============================] - 3s 60us/step - loss: 2.2859 - acc: 0.7956 - val_loss: 2.2969 - val_acc: 0.1108\n",
      "Epoch 19/50\n",
      "50000/50000 [==============================] - 3s 57us/step - loss: 2.2858 - acc: 0.7963 - val_loss: 2.2966 - val_acc: 0.1104\n",
      "Epoch 20/50\n",
      "50000/50000 [==============================] - 3s 56us/step - loss: 2.2858 - acc: 0.7982 - val_loss: 2.2968 - val_acc: 0.1115\n",
      "Epoch 21/50\n",
      "50000/50000 [==============================] - 3s 56us/step - loss: 2.2858 - acc: 0.7999 - val_loss: 2.2964 - val_acc: 0.1114\n",
      "Epoch 22/50\n",
      "50000/50000 [==============================] - 3s 57us/step - loss: 2.2858 - acc: 0.8020 - val_loss: 2.2968 - val_acc: 0.1116\n",
      "Epoch 23/50\n",
      "50000/50000 [==============================] - 3s 56us/step - loss: 2.2858 - acc: 0.8034 - val_loss: 2.2962 - val_acc: 0.1117\n",
      "Epoch 24/50\n",
      "50000/50000 [==============================] - 3s 56us/step - loss: 2.2858 - acc: 0.8043 - val_loss: 2.2964 - val_acc: 0.1119\n",
      "Epoch 25/50\n",
      "50000/50000 [==============================] - 3s 56us/step - loss: 2.2857 - acc: 0.8070 - val_loss: 2.2959 - val_acc: 0.1119\n",
      "Epoch 26/50\n",
      "50000/50000 [==============================] - 3s 55us/step - loss: 2.2857 - acc: 0.8085 - val_loss: 2.2961 - val_acc: 0.1119\n",
      "Epoch 27/50\n",
      "50000/50000 [==============================] - 3s 57us/step - loss: 2.2858 - acc: 0.8083 - val_loss: 2.2958 - val_acc: 0.1119\n",
      "Epoch 28/50\n",
      "50000/50000 [==============================] - 3s 56us/step - loss: 2.2857 - acc: 0.8126 - val_loss: 2.2960 - val_acc: 0.1122\n",
      "Epoch 29/50\n",
      "50000/50000 [==============================] - 3s 56us/step - loss: 2.2857 - acc: 0.8108 - val_loss: 2.2956 - val_acc: 0.1121\n",
      "Epoch 30/50\n",
      "50000/50000 [==============================] - 3s 57us/step - loss: 2.2857 - acc: 0.8133 - val_loss: 2.2959 - val_acc: 0.1135\n",
      "Epoch 31/50\n",
      "50000/50000 [==============================] - 3s 57us/step - loss: 2.2857 - acc: 0.8137 - val_loss: 2.2957 - val_acc: 0.1133\n",
      "Epoch 32/50\n",
      "50000/50000 [==============================] - 3s 59us/step - loss: 2.2857 - acc: 0.8174 - val_loss: 2.2953 - val_acc: 0.1131\n",
      "Epoch 33/50\n",
      "50000/50000 [==============================] - 3s 56us/step - loss: 2.2857 - acc: 0.8155 - val_loss: 2.2959 - val_acc: 0.1138\n",
      "Epoch 34/50\n",
      "50000/50000 [==============================] - 3s 57us/step - loss: 2.2856 - acc: 0.8169 - val_loss: 2.2954 - val_acc: 0.1134\n",
      "Epoch 35/50\n",
      "50000/50000 [==============================] - 3s 58us/step - loss: 2.2857 - acc: 0.8188 - val_loss: 2.2966 - val_acc: 0.1147\n",
      "Epoch 36/50\n",
      "50000/50000 [==============================] - 3s 63us/step - loss: 2.2857 - acc: 0.8187 - val_loss: 2.2949 - val_acc: 0.1137\n",
      "Epoch 37/50\n",
      "50000/50000 [==============================] - 3s 56us/step - loss: 2.2856 - acc: 0.8192 - val_loss: 2.2952 - val_acc: 0.1141\n",
      "Epoch 38/50\n",
      "50000/50000 [==============================] - 3s 55us/step - loss: 2.2856 - acc: 0.8213 - val_loss: 2.2947 - val_acc: 0.1141\n",
      "Epoch 39/50\n",
      "50000/50000 [==============================] - 3s 54us/step - loss: 2.2856 - acc: 0.8216 - val_loss: 2.2951 - val_acc: 0.1146\n",
      "Epoch 40/50\n",
      "50000/50000 [==============================] - 3s 56us/step - loss: 2.2856 - acc: 0.8209 - val_loss: 2.2946 - val_acc: 0.1146\n",
      "Epoch 41/50\n",
      "50000/50000 [==============================] - 3s 56us/step - loss: 2.2856 - acc: 0.8228 - val_loss: 2.2953 - val_acc: 0.1153\n",
      "Epoch 42/50\n",
      "50000/50000 [==============================] - 3s 56us/step - loss: 2.2856 - acc: 0.8243 - val_loss: 2.2946 - val_acc: 0.1145\n",
      "Epoch 43/50\n",
      "50000/50000 [==============================] - 3s 54us/step - loss: 2.2856 - acc: 0.8233 - val_loss: 2.2955 - val_acc: 0.1155\n",
      "Epoch 44/50\n",
      "50000/50000 [==============================] - 3s 55us/step - loss: 2.2856 - acc: 0.8244 - val_loss: 2.2946 - val_acc: 0.1151\n",
      "Epoch 45/50\n",
      "50000/50000 [==============================] - 3s 56us/step - loss: 2.2856 - acc: 0.8246 - val_loss: 2.2944 - val_acc: 0.1148\n",
      "Epoch 46/50\n",
      "50000/50000 [==============================] - 3s 56us/step - loss: 2.2856 - acc: 0.8245 - val_loss: 2.2951 - val_acc: 0.1159\n",
      "Epoch 47/50\n",
      "50000/50000 [==============================] - 3s 55us/step - loss: 2.2855 - acc: 0.8262 - val_loss: 2.2946 - val_acc: 0.1155\n",
      "Epoch 48/50\n",
      "50000/50000 [==============================] - 3s 55us/step - loss: 2.2856 - acc: 0.8269 - val_loss: 2.2944 - val_acc: 0.1155\n",
      "Epoch 49/50\n",
      "50000/50000 [==============================] - 3s 55us/step - loss: 2.2856 - acc: 0.8283 - val_loss: 2.2945 - val_acc: 0.1159\n",
      "Epoch 50/50\n",
      "50000/50000 [==============================] - 3s 61us/step - loss: 2.2855 - acc: 0.8292 - val_loss: 2.2942 - val_acc: 0.1153\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<keras.callbacks.History at 0x7f2be50fd4e0>"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Model \n",
    "# For batch size make sure it's big enough to leverage GPU parallelism \n",
    "#    As a basic rule you need to reduce learning rate as you reduce batch size\n",
    "# NOTE: Will be interesting to optimize for training time a bit by tweaking batch size\n",
    "model.fit(\n",
    "        X_train, \n",
    "        y_train_softmax_output,\n",
    "        batch_size=config.batch_size,\n",
    "        # Will use a fixed set of epochs rather than a stopping criteria \n",
    "        epochs=config.epochs,\n",
    "        validation_data=(X_test, y_test),\n",
    "        # Make sure we can see the pastries while they bake\n",
    "        callbacks=[WandbCallback(data_type=\"image\", labels=class_names)]\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
